{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0vUhMqkcDnTK"
      },
      "source": [
        "# Gradient Based Constraint Learning Demo\n",
        "\n",
        "Licensed under the Apache License, Version 2.0.\n",
        "\n",
        "This colab explores joint learning neural networks with soft constraints."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jPaQDSFDDlfG"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "import random\n",
        "import tensorflow as tf\n",
        "\n",
        "from tensorflow import keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bq7SE1IAvBAP"
      },
      "source": [
        "# Dataset and Task\n",
        "\n",
        "We test and validate our system over a common fairness dataset and task: [Adult Census Income dataset](https://archive.ics.uci.edu/ml/datasets/Census+Income). This data was extracted from the [1994 Census bureau database](http://www.census.gov/en.html) by Ronny Kohavi and Barry Becker. Our analysis aims at learning a model that does not bias predictions towards men over 50K through soft constraints."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "t5dqEvbHu1s5"
      },
      "outputs": [],
      "source": [
        "# ========================================================================\n",
        "# Constants\n",
        "# ========================================================================\n",
        "_TRAIN_PATH = ''\n",
        "_TEST_PATH = ''\n",
        "\n",
        "_COLUMNS = [\"age\", \"workclass\", \"fnlwgt\", \"education\", \"education_num\",\n",
        "           \"marital_status\", \"occupation\", \"relationship\", \"race\", \"gender\",\n",
        "           \"capital_gain\", \"capital_loss\", \"hours_per_week\", \"native_country\",\n",
        "           \"income_bracket\"]\n",
        "\n",
        "# ========================================================================\n",
        "# Seed Data\n",
        "# ========================================================================\n",
        "SEED = random.randint(-10000000, 10000000)\n",
        "print(\"Seed: %d\" % SEED)\n",
        "tf.random.set_seed(SEED)\n",
        "\n",
        "# ========================================================================\n",
        "# Load Data\n",
        "# ========================================================================\n",
        "with tf.io.gfile.GFile(_TRAIN_PATH, 'r') as csv_file:\n",
        "  train_df = pd.read_csv(csv_file, names=_COLUMNS, sep=r'\\s*,\\s*', na_values=\"?\").dropna(how=\"any\", axis=0)\n",
        "\n",
        "with tf.io.gfile.GFile(_TEST_PATH, 'r') as csv_file:\n",
        "  test_df = pd.read_csv(csv_file, names=_COLUMNS, skiprows=[0], sep=r'\\s*,\\s*', na_values=\"?\").dropna(how=\"any\", axis=0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S8U7bmZqFPUv"
      },
      "source": [
        "# Feature Columns\n",
        "\n",
        "The following code was taken from [intro_to_fairness](https://colab.sandbox.google.com/notebooks/mlcc/intro_to_fairness.ipynb#scrollTo=tAG5hUJwx725). In short, Tensorflow requires a mapping of data and so every column is specified."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "hOGqN4iKEsJp"
      },
      "outputs": [],
      "source": [
        "#@title Prepare Dataset\n",
        "# ========================================================================\n",
        "# Categorical Feature Columns\n",
        "# ========================================================================\n",
        "# Unknown length\n",
        "occupation = tf.feature_column.categorical_column_with_hash_bucket(\n",
        "    \"occupation\", hash_bucket_size=1000)\n",
        "native_country = tf.feature_column.categorical_column_with_hash_bucket(\n",
        "    \"native_country\", hash_bucket_size=1000)\n",
        "\n",
        "# Known length\n",
        "gender = tf.feature_column.categorical_column_with_vocabulary_list(\n",
        "    \"gender\", [\"Female\", \"Male\"])\n",
        "race = tf.feature_column.categorical_column_with_vocabulary_list(\n",
        "    \"race\", [\n",
        "        \"White\", \"Asian-Pac-Islander\", \"Amer-Indian-Eskimo\", \"Other\", \"Black\"\n",
        "    ])\n",
        "education = tf.feature_column.categorical_column_with_vocabulary_list(\n",
        "    \"education\", [\n",
        "        \"Bachelors\", \"HS-grad\", \"11th\", \"Masters\", \"9th\",\n",
        "        \"Some-college\", \"Assoc-acdm\", \"Assoc-voc\", \"7th-8th\",\n",
        "        \"Doctorate\", \"Prof-school\", \"5th-6th\", \"10th\", \"1st-4th\",\n",
        "        \"Preschool\", \"12th\"\n",
        "    ])\n",
        "marital_status = tf.feature_column.categorical_column_with_vocabulary_list(\n",
        "    \"marital_status\", [\n",
        "        \"Married-civ-spouse\", \"Divorced\", \"Married-spouse-absent\",\n",
        "        \"Never-married\", \"Separated\", \"Married-AF-spouse\", \"Widowed\"\n",
        "    ])\n",
        "relationship = tf.feature_column.categorical_column_with_vocabulary_list(\n",
        "    \"relationship\", [\n",
        "        \"Husband\", \"Not-in-family\", \"Wife\", \"Own-child\", \"Unmarried\",\n",
        "        \"Other-relative\"\n",
        "    ])\n",
        "workclass = tf.feature_column.categorical_column_with_vocabulary_list(\n",
        "    \"workclass\", [\n",
        "        \"Self-emp-not-inc\", \"Private\", \"State-gov\", \"Federal-gov\",\n",
        "        \"Local-gov\", \"?\", \"Self-emp-inc\", \"Without-pay\", \"Never-worked\"\n",
        "    ])\n",
        "\n",
        "# ========================================================================\n",
        "# Numeric Feature Columns\n",
        "# ========================================================================\n",
        "age = tf.feature_column.numeric_column(\"age\")\n",
        "age_buckets = tf.feature_column.bucketized_column(age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])\n",
        "fnlwgt = tf.feature_column.numeric_column(\"fnlwgt\")\n",
        "education_num = tf.feature_column.numeric_column(\"education_num\")\n",
        "capital_gain = tf.feature_column.numeric_column(\"capital_gain\")\n",
        "capital_loss = tf.feature_column.numeric_column(\"capital_loss\")\n",
        "hours_per_week = tf.feature_column.numeric_column(\"hours_per_week\")\n",
        "\n",
        "# ========================================================================\n",
        "# Specify Features\n",
        "# ========================================================================\n",
        "deep_columns = [\n",
        "    tf.feature_column.indicator_column(workclass),\n",
        "    tf.feature_column.indicator_column(education),\n",
        "    tf.feature_column.indicator_column(age_buckets),\n",
        "    tf.feature_column.indicator_column(gender),\n",
        "    tf.feature_column.indicator_column(relationship),\n",
        "    tf.feature_column.embedding_column(native_country, dimension=8),\n",
        "    tf.feature_column.embedding_column(occupation, dimension=8),\n",
        "]\n",
        "\n",
        "features = {\n",
        "  'age': tf.keras.Input(shape=(1,), name='age'),\n",
        "  'education': tf.keras.Input(shape=(1,), name='education', dtype=tf.string),\n",
        "  'gender': tf.keras.Input(shape=(1,), name='gender', dtype=tf.string),\n",
        "  'native_country': tf.keras.Input(shape=(1,), name='native_country', dtype=tf.string),\n",
        "  'occupation': tf.keras.Input(shape=(1,), name='occupation', dtype=tf.string),\n",
        "  'relationship': tf.keras.Input(shape=(1,), name='relationship', dtype=tf.string),\n",
        "  'workclass': tf.keras.Input(shape=(1,), name='workclass', dtype=tf.string),\n",
        "}\n",
        "\n",
        "# ========================================================================\n",
        "# Create Dataset\n",
        "# ========================================================================\n",
        "def df_to_dataset(dataframe, shuffle=True, batch_size=512):\n",
        "  dataframe = dataframe.copy()\n",
        "  labels = dataframe.pop('income_bracket').apply(lambda x: \"\u003e50K\" in x).astype(int)\n",
        "  ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))\n",
        "  if shuffle:\n",
        "    ds = ds.shuffle(buffer_size=len(dataframe))\n",
        "  ds = ds.batch(batch_size)\n",
        "  return ds"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "7SCi3pLhBANM"
      },
      "outputs": [],
      "source": [
        "#@title Helper Functions\n",
        "def confusion_matrix(predictions, labels, threshold=0.5):\n",
        "  tp = 0\n",
        "  tn = 0\n",
        "  fp = 0\n",
        "  fn = 0\n",
        "  for prediction, label in zip(predictions, labels):\n",
        "    if prediction \u003e threshold:\n",
        "      if label == 1:\n",
        "        tp += 1\n",
        "      else:\n",
        "        fp += 1\n",
        "    else:\n",
        "      if label == 0:\n",
        "        tn += 1\n",
        "      else:\n",
        "        fn += 1\n",
        "  return (tp, tn, fp, fn)\n",
        "\n",
        "def remove_group(dataframe, predictions, group):\n",
        "  dataframe = dataframe.copy()\n",
        "  dataframe['predictions'] = predictions\n",
        "  dataframe = dataframe[dataframe.gender != group]\n",
        "\n",
        "  group_predictions = dataframe.pop('predictions')\n",
        "\n",
        "  return dataframe, group_predictions\n",
        "\n",
        "def print_accuracy(dataframe, predictions, threshold=0.5):\n",
        "  dataframe = dataframe.copy()\n",
        "  labels = dataframe.pop('income_bracket').apply(lambda x: \"\u003e50K\" in x).astype(int)\n",
        "\n",
        "  tp, tn, fp, fn = confusion_matrix(predictions, labels, threshold=threshold)\n",
        "  print(\"True Positives: %d True Negatives: %d False Positives %d False Negatives: %d\" % (tp, tn, fp, fn))\n",
        "  print(\"Accuracy: %0.5f\" % ((tp+tn) / (tp + tn + fp + fn)))\n",
        "  print(\"Positive Accuracy: %0.5f\" % (tp / (tp + fp)))\n",
        "  print(\"Negative Accuracy: %0.5f\" % (tn / (tn + fn)))\n",
        "  print(\"Percentage Predicted over \u003e50K: %0.5f\" % (((tp + fp) / (tp + tn + fp + fn)) * 100))\n",
        "\n",
        "  return (tp, tn, fp, fn)\n",
        "\n",
        "def parity(m_tp, m_fp, m_tn, m_fn, f_tp, f_fp, f_tn, f_fn):\n",
        "  return ((m_tp + m_fp) / (m_tp + m_tn + m_fp + m_fn)) - ((f_tp + f_fp) / (f_tp + f_tn + f_fp + f_fn))\n",
        "\n",
        "def print_title(title, print_length=50):\n",
        "  print(('-' * print_length) + '\\n' + title + '\\n' + ('-' * print_length))\n",
        "\n",
        "def print_analysis(train_df, train_predictions, test_df, test_predictions):\n",
        "  print_title(\"Train Accuracy\")\n",
        "  print_accuracy(train_df, train_predictions)\n",
        "\n",
        "  print_title(\"Full Test Accuracy\")\n",
        "  print_accuracy(test_df, test_predictions)\n",
        "\n",
        "  print_title(\"Male Test Accuracy\")\n",
        "  male_df, male_pred = remove_group(test_df, test_predictions, \"Female\")\n",
        "  m_tp, m_tn, m_fp, m_fn = print_accuracy(male_df, male_pred)\n",
        "\n",
        "  print_title(\"Female Test Accuracy\")\n",
        "  female_df, female_pred = remove_group(test_df, test_predictions, \"Male\")\n",
        "  f_tp, f_tn, f_fp, f_fn = print_accuracy(female_df, female_pred)\n",
        "\n",
        "  print_title(\"Parity\")\n",
        "  print(parity(m_tp, m_fp, m_tn, m_fn, f_tp, f_fp, f_tn, f_fn))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G8brN9Ppx3--"
      },
      "source": [
        "# Create and Run Non-Constrained Neural Model\n",
        "\n",
        "Defining our neural model that will be used as a comparison. Note: this model was purposfully designed to be simplistic, as it is trying to highlight the benifit to learning with soft constraints."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1k0_BxR9Ls9X"
      },
      "outputs": [],
      "source": [
        "def build_model(feature_columns, features):\n",
        "  feature_layer = tf.keras.layers.DenseFeatures(feature_columns)\n",
        "  hidden_layer_1 = tf.keras.layers.Dense(1024, activation='relu')(feature_layer(features))\n",
        "  hidden_layer_2 = tf.keras.layers.Dense(512, activation='relu')(hidden_layer_1)\n",
        "  output = tf.keras.layers.Dense(1, activation='sigmoid')(hidden_layer_2)\n",
        "\n",
        "  model = tf.keras.Model([v for v in features.values()], output)\n",
        "\n",
        "  model.compile(optimizer='adam',\n",
        "                loss='mse',\n",
        "                metrics=['accuracy'])\n",
        "\n",
        "  return model\n",
        "\n",
        "baseline_model = build_model(deep_columns, features)\n",
        "baseline_model.fit(df_to_dataset(train_df), epochs=50)\n",
        "\n",
        "test_predictions = baseline_model.predict(df_to_dataset(test_df, shuffle=False))\n",
        "baseline_model.evaluate(df_to_dataset(test_df))\n",
        "\n",
        "train_predictions = baseline_model.predict(df_to_dataset(train_df, shuffle=False))\n",
        "baseline_model.evaluate(df_to_dataset(train_df))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ji1jtzBvyLTo"
      },
      "source": [
        "# Analyze Non-Constrained Results\n",
        "\n",
        "For this example we look at the fairness constraint that the protected group (gender) should have no predictive difference between classes. In this situation this means that the ratio of positive predictions should be the same between male and female.\n",
        "\n",
        "Note: this is by no means the only fairness constraint needed to have a fair model, and in fact can result in some doubious results (as seen in the follwoing section).\n",
        "\n",
        "The results do clearly show a skew in ratios as males are have a higher ratio of \u003e50k predictions."
      ]
    },
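    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The parity value printed by the analysis is the difference in positive-prediction rates between the two groups. As a sketch, with subscripts $m$ (male) and $f$ (female) chosen here to mirror the arguments of the `parity` helper:\n",
        "\n",
        "$$\\text{parity} = \\frac{TP_m + FP_m}{TP_m + TN_m + FP_m + FN_m} - \\frac{TP_f + FP_f}{TP_f + TN_f + FP_f + FN_f}$$\n",
        "\n",
        "A value of zero means both groups receive \u003e50K predictions at the same rate, and a positive value means males are predicted \u003e50K more often. For example, with hypothetical counts (not results from this notebook) of 30 positive predictions out of 100 males and 10 out of 100 females, the parity would be $0.30 - 0.10 = 0.20$."
      ]
    },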
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wWQJpfMryH1l"
      },
      "outputs": [],
      "source": [
        "print_analysis(train_df, train_predictions, test_df, test_predictions)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xpFJjNOfyurV"
      },
      "source": [
        "# Define Constraints\n",
        "\n",
        "This requires a constrained loss function and a custom train step within the keras model class."
      ]
    },
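    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to write the loss defined in the next cell, as a sketch: the symbols $p_i$ (sigmoid output for example $i$), $G_g$ (the examples of gender $g$ in a batch), $t$ (`threshold`) and $w$ (`weight`) are notation introduced here, not names from the code.\n",
        "\n",
        "$$\\hat{r}_g = \\frac{1}{|G_g|} \\sum_{i \\in G_g} \\mathbb{1}[p_i \u003e t] \\, p_i$$\n",
        "\n",
        "$$L_{\\text{constraint}} = w \\left( \\max(\\hat{r}_m - \\hat{r}_f, 0) + \\max(\\hat{r}_f - \\hat{r}_m, 0) \\right) = w \\, |\\hat{r}_m - \\hat{r}_f|$$\n",
        "\n",
        "$$L = L_{\\text{standard}} + L_{\\text{constraint}}$$\n",
        "\n",
        "The custom train step minimizes $L$, where $L_{\\text{standard}}$ is the compiled loss (`mse` in this colab). The indicator $\\mathbb{1}[p_i \u003e t]$ carries no gradient, so gradients flow only through predictions above the threshold."
      ]
    },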
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6d3jk0Pu8hTP"
      },
      "outputs": [],
      "source": [
        "def constrained_loss(data, logits, threshold=0.5, weight=3):\n",
        "  \"\"\"Linear constrained loss for equal ratio prediction for the protected group.\n",
        "\n",
        "  The constraint: (#Female \u003e50k / #Total Female) - (#Male \u003e50k / #Total Male)\n",
        "  This constraint penalizes predictions between the protected group (gender),\n",
        "  such that the ratio between all classes must be the same.\n",
        "  An important note: to maintian differentability we do not use #Female \u003e50k\n",
        "  (which requires a round operation), instead we set values below the threshold\n",
        "  to zero, and sum the logits.\n",
        "\n",
        "  Args:\n",
        "  data: Input features.\n",
        "  logits: Predictions made in the logit.\n",
        "  threshold: Binary threshold for predicting positive and negative labels.\n",
        "  weight: Weight of the constrained loss.\n",
        "\n",
        "  Returns:\n",
        "  A scalar loss of the constraint violations.\n",
        "  \"\"\"\n",
        "  gender_label, gender_idx, gender_count = tf.unique_with_counts(data['gender'], out_idx=tf.int32, name=None)\n",
        "  cut_logits = tf.reshape(tf.cast(logits \u003e threshold, logits.dtype) * logits, [-1])\n",
        "\n",
        "  def f1():\n",
        "    return gender_idx\n",
        "  def f2():\n",
        "    return tf.cast(tf.math.logical_not(tf.cast(gender_idx, tf.bool)), tf.int32)\n",
        "\n",
        "  # Load male indexes as ones and female indexes to zeros.\n",
        "  male_index = tf.cond(tf.reduce_all(tf.equal(gender_label, tf.constant([\"Male\", \"Female\"]))), f1, f2)\n",
        "  # Cast the integers to float32 to do a multiplication with the logits.\n",
        "  male_index = tf.cast(male_index, tf.float32)\n",
        "  # (#Male \u003e 50k / #Total Male)\n",
        "  male_prob = tf.divide(tf.reduce_sum(tf.multiply(cut_logits, male_index)), tf.reduce_sum(male_index))\n",
        "\n",
        "  # Flip all female indexes to one and male indexes to zeros.\n",
        "  female_index = tf.math.logical_not(tf.cast(male_index, tf.bool))\n",
        "  # Cast the integers to float32 to do a multiplication with the logits.\n",
        "  female_index = tf.cast(female_index, tf.float32)\n",
        "  # (#Female \u003e 50k / #Total Female)\n",
        "  female_prob = tf.divide(tf.reduce_sum(tf.multiply(cut_logits, female_index)), tf.reduce_sum(female_index))\n",
        "\n",
        "  # Since tf.math.abs is not differentable, separate the loss into two hinges.\n",
        "  loss = tf.add(tf.maximum(male_prob - female_prob, 0.0), tf.maximum(female_prob - male_prob, 0.0))\n",
        "  return tf.multiply(loss, weight)\n",
        "\n",
        "class StructureModel(keras.Model):\n",
        "  def train_step(self, data):\n",
        "    features, labels = data\n",
        "\n",
        "    with tf.GradientTape() as tape:\n",
        "      logits = self(features, training=True)\n",
        "      standard_loss = self.compiled_loss(labels, logits, regularization_losses=self.losses)\n",
        "      constraint_loss = constrained_loss(features, logits)\n",
        "      loss =  standard_loss + constraint_loss\n",
        "\n",
        "    trainable_vars = self.trainable_variables\n",
        "    gradients = tape.gradient(loss, trainable_vars)\n",
        "\n",
        "    self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n",
        "    self.compiled_metrics.update_state(labels, logits)\n",
        "\n",
        "    return {m.name: m.result() for m in self.metrics}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PoZCYYTwzidW"
      },
      "source": [
        "# Build and Run Constrained Neural Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rJ3-B8Bjzl30"
      },
      "outputs": [],
      "source": [
        "def build_constrained_model(feature_columns, features):\n",
        "  feature_layer = tf.keras.layers.DenseFeatures(feature_columns)\n",
        "  hidden_layer_1 = tf.keras.layers.Dense(1024, activation='relu')(feature_layer(features))\n",
        "  hidden_layer_2 = tf.keras.layers.Dense(512, activation='relu')(hidden_layer_1)\n",
        "  output = tf.keras.layers.Dense(1, activation='sigmoid')(hidden_layer_2)\n",
        "\n",
        "  model = StructureModel([v for v in features.values()], output)\n",
        "\n",
        "  model.compile(optimizer='adam',\n",
        "                loss='mse',\n",
        "                metrics=['accuracy'])\n",
        "\n",
        "  return model\n",
        "\n",
        "constrained_model = build_constrained_model(deep_columns, features)\n",
        "constrained_model.fit(df_to_dataset(train_df), epochs=50)\n",
        "\n",
        "test_predictions = constrained_model.predict(df_to_dataset(test_df, shuffle=False))\n",
        "constrained_model.evaluate(df_to_dataset(test_df))\n",
        "\n",
        "train_predictions = constrained_model.predict(df_to_dataset(train_df, shuffle=False))\n",
        "constrained_model.evaluate(df_to_dataset(train_df))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LhvI8j62ltWt"
      },
      "source": [
        "# Analyze Constrained Results\n",
        "\n",
        "Ideally this constraint should correct the ratio imbalance between the protected groups (gender). This means our parity should be very close to zero.\n",
        "\n",
        "Note: This constraint does not mean the neural classifier is guaranteed to generalize and make better predictions. It is more likely to attempt to balance the class prediction ratio in the simplest fashion (resulting in a worse accuracy)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oQrWoCYmlvbA"
      },
      "outputs": [],
      "source": [
        "print_analysis(train_df, train_predictions, test_df, test_predictions)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "gradient_based_constraint_learning_demo.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1ri2Frd6c1Ho7lC7_b91X7q6onP5Yn1uP",
          "timestamp": 1633637940232
        },
        {
          "file_id": "/piper/depot/google3/experimental/users/cfpryor/neural_learning_with_soft_constraints.ipynb",
          "timestamp": 1627691047339
        },
        {
          "file_id": "/piper/depot/google3/experimental/users/cfpryor/neural_learning_with_soft_constraints.ipynb?workspaceId=cfpryor:cfpryor_intern2021::citc",
          "timestamp": 1626731982182
        },
        {
          "file_id": "1s_gTqW9Oy9G1pU7VrtIUNlL3SUJcGAZP",
          "timestamp": 1626478300089
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
